{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning a masked language model (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# To run the training on TPU, you will need to uncomment the following line:\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to set up Git: adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForMaskedLM\n",
    "\n",
    "model_checkpoint = \"distilbert-base-uncased\"\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> DistilBERT number of parameters: 67M'\n",
       "'>>> BERT number of parameters: 110M'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "distilbert_num_parameters = model.num_parameters() / 1_000_000\n",
    "print(f\"'>>> DistilBERT number of parameters: {round(distilbert_num_parameters)}M'\")\n",
    "print(f\"'>>> BERT number of parameters: 110M'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"This is a great [MASK].\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> This is a great deal.'\n",
       "'>>> This is a great success.'\n",
       "'>>> This is a great adventure.'\n",
       "'>>> This is a great idea.'\n",
       "'>>> This is a great feat.'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "token_logits = model(**inputs).logits\n",
    "# Find the location of [MASK] and extract its logits\n",
    "mask_token_index = torch.where(inputs[\"input_ids\"] == tokenizer.mask_token_id)[1]\n",
    "mask_token_logits = token_logits[0, mask_token_index, :]\n",
    "# Pick the [MASK] candidates with the highest logits\n",
    "top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()\n",
    "\n",
    "for token in top_5_tokens:\n",
    "    print(f\"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    unsupervised: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "imdb_dataset = load_dataset(\"imdb\")\n",
    "imdb_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "'>>> Review: This is your typical Priyadarshan movie--a bunch of loony characters out on some silly mission. His signature climax has the entire cast of the film coming together and fighting each other in some crazy moshpit over hidden money. Whether it is a winning lottery ticket in Malamaal Weekly, black money in Hera Pheri, \"kodokoo\" in Phir Hera Pheri, etc., etc., the director is becoming ridiculously predictable. Don\\'t get me wrong; as clichéd and preposterous his movies may be, I usually end up enjoying the comedy. However, in most his previous movies there has actually been some good humor, (Hungama and Hera Pheri being noteworthy ones). Now, the hilarity of his films is fading as he is using the same formula over and over again.<br /><br />Songs are good. Tanushree Datta looks awesome. Rajpal Yadav is irritating, and Tusshar is not a whole lot better. Kunal Khemu is OK, and Sharman Joshi is the best.'\n",
       "'>>> Label: 0'\n",
       "\n",
       "'>>> Review: Okay, the story makes no sense, the characters lack any dimensionally, the best dialogue is ad-libs about the low quality of movie, the cinematography is dismal, and only editing saves a bit of the muddle, but Sam\" Peckinpah directed the film. Somehow, his direction is not enough. For those who appreciate Peckinpah and his great work, this movie is a disappointment. Even a great cast cannot redeem the time the viewer wastes with this minimal effort.<br /><br />The proper response to the movie is the contempt that the director San Peckinpah, James Caan, Robert Duvall, Burt Young, Bo Hopkins, Arthur Hill, and even Gig Young bring to their work. Watch the great Peckinpah films. Skip this mess.'\n",
       "'>>> Label: 0'\n",
       "\n",
       "'>>> Review: I saw this movie at the theaters when I was about 6 or 7 years old. I loved it then, and have recently come to own a VHS version. <br /><br />My 4 and 6 year old children love this movie and have been asking again and again to watch it. <br /><br />I have enjoyed watching it again too. Though I have to admit it is not as good on a little TV.<br /><br />I do not have older children so I do not know what they would think of it. <br /><br />The songs are very cute. My daughter keeps singing them over and over.<br /><br />Hope this helps.'\n",
       "'>>> Label: 1'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = imdb_dataset[\"train\"].shuffle(seed=42).select(range(3))\n",
    "\n",
    "for row in sample:\n",
    "    print(f\"\\n'>>> Review: {row['text']}'\")\n",
    "    print(f\"'>>> Label: {row['label']}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'word_ids'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'word_ids'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    unsupervised: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'word_ids'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize the `text` column of a batch of examples.\n",
    "\n",
    "    When the tokenizer is a fast one, a `word_ids` entry is added that holds,\n",
    "    for every row of the batch, the value returned by `word_ids(row)`.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(examples[\"text\"])\n",
    "    if tokenizer.is_fast:\n",
    "        num_rows = len(encoded[\"input_ids\"])\n",
    "        encoded[\"word_ids\"] = [encoded.word_ids(row) for row in range(num_rows)]\n",
    "    return encoded\n",
    "\n",
    "\n",
    "# Use batched=True to activate fast multithreading!\n",
    "tokenized_datasets = imdb_dataset.map(\n",
    "    tokenize_function, batched=True, remove_columns=[\"text\", \"label\"]\n",
    ")\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "512"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.model_max_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunk_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> Review 0 length: 200'\n",
       "'>>> Review 1 length: 559'\n",
       "'>>> Review 2 length: 192'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Slicing produces a list of lists for each feature\n",
    "tokenized_samples = tokenized_datasets[\"train\"][:3]\n",
    "\n",
    "for idx, sample in enumerate(tokenized_samples[\"input_ids\"]):\n",
    "    print(f\"'>>> Review {idx} length: {len(sample)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> Concatenated reviews length: 951'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concatenated_examples = {\n",
    "    k: sum(tokenized_samples[k], []) for k in tokenized_samples.keys()\n",
    "}\n",
    "total_length = len(concatenated_examples[\"input_ids\"])\n",
    "print(f\"'>>> Concatenated reviews length: {total_length}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 128'\n",
       "'>>> Chunk length: 55'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chunks = {\n",
    "    k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]\n",
    "    for k, t in concatenated_examples.items()\n",
    "}\n",
    "\n",
    "for chunk in chunks[\"input_ids\"]:\n",
    "    print(f\"'>>> Chunk length: {len(chunk)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group_texts(examples):\n",
    "    \"\"\"Concatenate a batch of tokenized examples and cut it into fixed-size chunks.\n",
    "\n",
    "    Every column of `examples` (a mapping from column name to a list of\n",
    "    lists) is flattened into one long list and split into pieces of\n",
    "    `chunk_size` items. A trailing piece shorter than `chunk_size` is\n",
    "    dropped. A `labels` column is added as a copy of the chunked `input_ids`.\n",
    "    \"\"\"\n",
    "    from itertools import chain\n",
    "\n",
    "    # Concatenate all texts; chain runs in linear time, unlike sum(lists, [])\n",
    "    concatenated_examples = {\n",
    "        k: list(chain.from_iterable(examples[k])) for k in examples.keys()\n",
    "    }\n",
    "    # Compute length of concatenated texts\n",
    "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
    "    # We drop the last chunk if it's smaller than chunk_size\n",
    "    total_length = (total_length // chunk_size) * chunk_size\n",
    "    # Split by chunks of chunk_size\n",
    "    result = {\n",
    "        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]\n",
    "        for k, t in concatenated_examples.items()\n",
    "    }\n",
    "    # Create a new labels column\n",
    "    result[\"labels\"] = result[\"input_ids\"].copy()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n",
       "        num_rows: 61289\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n",
       "        num_rows: 59905\n",
       "    })\n",
       "    unsupervised: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n",
       "        num_rows: 122963\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_datasets = tokenized_datasets.map(group_texts, batched=True)\n",
    "lm_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\".... at.......... high. a classic line : inspector : i'm here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn't! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(lm_datasets[\"train\"][1][\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = [lm_datasets[\"train\"][i] for i in range(2)]\n",
    "for sample in samples:\n",
    "    _ = sample.pop(\"word_ids\")\n",
    "\n",
    "for chunk in data_collator(samples)[\"input_ids\"]:\n",
    "    print(f\"\\n'>>> {tokenizer.decode(chunk)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "\n",
    "from transformers import default_data_collator\n",
    "\n",
    "wwm_probability = 0.2\n",
    "\n",
    "\n",
    "def whole_word_masking_data_collator(features):\n",
    "    \"\"\"Collate `features` into a batch after masking randomly chosen whole words.\n",
    "\n",
    "    Each feature must provide `input_ids`, `labels` and `word_ids`. Every word\n",
    "    is selected independently with probability `wwm_probability`; all tokens\n",
    "    of a selected word are replaced by the mask token and keep their label,\n",
    "    while every other position gets the label -100. The `word_ids` entry is\n",
    "    dropped before the features are handed to `default_data_collator`.\n",
    "\n",
    "    Masking is applied to copies, so the dicts and `input_ids` lists passed in\n",
    "    by the caller are left unchanged and the same features can be collated\n",
    "    more than once.\n",
    "    \"\"\"\n",
    "    masked_features = []\n",
    "    for feature in features:\n",
    "        # Shallow copy: popping and reassigning keys must not affect the caller\n",
    "        feature = dict(feature)\n",
    "        word_ids = feature.pop(\"word_ids\")\n",
    "\n",
    "        # Create a map between words and corresponding token indices\n",
    "        mapping = collections.defaultdict(list)\n",
    "        current_word_index = -1\n",
    "        current_word = None\n",
    "        for idx, word_id in enumerate(word_ids):\n",
    "            if word_id is None:\n",
    "                # A token without a word id ends the current word, so equal\n",
    "                # ids on either side of it are never merged into one word\n",
    "                current_word = None\n",
    "                continue\n",
    "            if word_id != current_word:\n",
    "                current_word = word_id\n",
    "                current_word_index += 1\n",
    "            mapping[current_word_index].append(idx)\n",
    "\n",
    "        # Randomly mask words\n",
    "        mask = np.random.binomial(1, wwm_probability, (len(mapping),))\n",
    "        # Copy the ids so the masking below does not write into the caller's list\n",
    "        input_ids = list(feature[\"input_ids\"])\n",
    "        labels = feature[\"labels\"]\n",
    "        new_labels = [-100] * len(labels)\n",
    "        for word_index in np.where(mask)[0]:\n",
    "            for idx in mapping[word_index.item()]:\n",
    "                new_labels[idx] = labels[idx]\n",
    "                input_ids[idx] = tokenizer.mask_token_id\n",
    "        feature[\"input_ids\"] = input_ids\n",
    "        feature[\"labels\"] = new_labels\n",
    "        masked_features.append(feature)\n",
    "\n",
    "    return default_data_collator(masked_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> [CLS] bromwell high is a cartoon comedy [MASK] it ran at the same time as some other programs about school life, such as \" teachers \". my 35 years in the teaching profession lead me to believe that bromwell high\\'s satire is much closer to reality than is \" teachers \". the scramble to survive financially, the insightful students who can see right through their pathetic teachers\\'pomp, the pettiness of the whole situation, all remind me of the schools i knew and their students. when i saw the episode in which a student repeatedly tried to burn down the school, i immediately recalled.....'\n",
       "\n",
       "'>>> .... [MASK] [MASK] [MASK] [MASK]....... high. a classic line : inspector : i\\'m here to sack one of your teachers. student : welcome to bromwell high. i expect that many adults of my age think that bromwell high is far fetched. what a pity that it isn\\'t! [SEP] [CLS] homelessness ( or houselessness as george carlin stated ) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. most people think of the homeless'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples = [lm_datasets[\"train\"][i] for i in range(2)]\n",
    "batch = whole_word_masking_data_collator(samples)\n",
    "\n",
    "for chunk in batch[\"input_ids\"]:\n",
    "    print(f\"\\n'>>> {tokenizer.decode(chunk)}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n",
       "        num_rows: 10000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['attention_mask', 'input_ids', 'labels', 'word_ids'],\n",
       "        num_rows: 1000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_size = 10_000\n",
    "test_size = int(0.1 * train_size)\n",
    "\n",
    "downsampled_dataset = lm_datasets[\"train\"].train_test_split(\n",
    "    train_size=train_size, test_size=test_size, seed=42\n",
    ")\n",
    "downsampled_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "batch_size = 64\n",
    "# Show the training loss with every epoch\n",
    "logging_steps = len(downsampled_dataset[\"train\"]) // batch_size\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-finetuned-imdb\",\n",
    "    overwrite_output_dir=True,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    push_to_hub=True,\n",
    "    fp16=True,\n",
    "    logging_steps=logging_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=downsampled_dataset[\"train\"],\n",
    "    eval_dataset=downsampled_dataset[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ">>> Perplexity: 21.75"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "eval_results = trainer.evaluate()\n",
    "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ">>> Perplexity: 11.32"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_results = trainer.evaluate()\n",
    "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def insert_random_mask(batch):\n",
    "    \"\"\"Run `data_collator` over a batch of columns and return the result as new columns.\n",
    "\n",
    "    `batch` maps column names to per-row values. The rows are rebuilt as one\n",
    "    dict each, passed to `data_collator`, and every entry of its output is\n",
    "    returned as a NumPy array under the original key prefixed with `masked_`.\n",
    "    \"\"\"\n",
    "    column_names = list(batch)\n",
    "    rows = zip(*batch.values())\n",
    "    features = [dict(zip(column_names, row)) for row in rows]\n",
    "    masked_inputs = data_collator(features)\n",
    "\n",
    "    # One \"masked_<name>\" column per entry returned by the collator\n",
    "    masked_columns = {}\n",
    "    for name, tensor in masked_inputs.items():\n",
    "        masked_columns[\"masked_\" + name] = tensor.numpy()\n",
    "    return masked_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "downsampled_dataset = downsampled_dataset.remove_columns([\"word_ids\"])\n",
    "eval_dataset = downsampled_dataset[\"test\"].map(\n",
    "    insert_random_mask,\n",
    "    batched=True,\n",
    "    remove_columns=downsampled_dataset[\"test\"].column_names,\n",
    ")\n",
    "eval_dataset = eval_dataset.rename_columns(\n",
    "    {\n",
    "        \"masked_input_ids\": \"input_ids\",\n",
    "        \"masked_attention_mask\": \"attention_mask\",\n",
    "        \"masked_labels\": \"labels\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator\n",
    "\n",
    "batch_size = 64\n",
    "train_dataloader = DataLoader(\n",
    "    downsampled_dataset[\"train\"],\n",
    "    shuffle=True,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=data_collator,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    eval_dataset, batch_size=batch_size, collate_fn=default_data_collator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "# Reload the pretrained checkpoint so this run does not continue from the\n",
    "# weights already fine-tuned with the Trainer earlier in the notebook\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 3\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'lewtun/distilbert-base-uncased-finetuned-imdb-accelerate'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_hub import get_full_repo_name\n",
    "\n",
    "model_name = \"distilbert-base-uncased-finetuned-imdb-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import Repository\n",
    "\n",
    "output_dir = model_name\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ">>> Epoch 0: Perplexity: 11.397545307900472\n",
       ">>> Epoch 1: Perplexity: 10.904909330983092\n",
       ">>> Epoch 2: Perplexity: 10.729503505340409"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import math\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    for batch in train_dataloader:\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    losses = []\n",
    "    for step, batch in enumerate(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "\n",
    "        loss = outputs.loss\n",
    "        losses.append(accelerator.gather(loss.repeat(batch_size)))\n",
    "\n",
    "    losses = torch.cat(losses)\n",
    "    losses = losses[: len(eval_dataset)]\n",
    "    try:\n",
    "        perplexity = math.exp(torch.mean(losses))\n",
    "    except OverflowError:\n",
    "        perplexity = float(\"inf\")\n",
    "\n",
    "    print(f\">>> Epoch {epoch}: Perplexity: {perplexity}\")\n",
    "\n",
    "    # Save and upload\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "mask_filler = pipeline(\n",
    "    \"fill-mask\", model=\"huggingface-course/distilbert-base-uncased-finetuned-imdb\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> this is a great movie.'\n",
       "'>>> this is a great film.'\n",
       "'>>> this is a great story.'\n",
       "'>>> this is a great movies.'\n",
       "'>>> this is a great character.'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = mask_filler(text)\n",
    "\n",
    "for pred in preds:\n",
    "    print(f\">>> {pred['sequence']}\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Fine-tuning a masked language model (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
